import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import datasets, layers, models
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange
from tqdm.keras import TqdmCallback
import os
import math
The training data should be prepared as described by Dong et al. in Image Super-Resolution Using Deep Convolutional Networks:
In the training phase, the ground truth images $X_i$ are prepared as $f_{sub}\times f_{sub}\times c$-pixel sub-images randomly cropped from the training images. By “sub-images” we mean these samples are treated as small “images” rather than “patches”, in the sense that “patches” are overlapping and require some averaging as post-processing but “sub-images” need not. To synthesize the low-resolution samples $Y_i$, we blur a sub-image by a Gaussian kernel, sub-sample it by the upscaling factor, and upscale it by the same factor via bicubic interpolation.
The paper we're attempting to replicate did not train their model on color images, so we will convert the data set to greyscale before cropping.
# Dataset parameters
IMAGE_DIR = r'./SR_training_datasets/T91/'  # directory containing the training images (T91 dataset)
SUBIMAGE_SIZE = [33, 33, 1]  # height, width, channels of each cropped sub-image
SUBIMAGE_QUANTITY = 289  # number of random crops taken from each source image
GAUSSIAN_SHAPE = [3, 3]  # Gaussian blur kernel size applied before downsampling
GAUSSIAN_SIGMA = 1.0  # standard deviation of the Gaussian blur kernel
SCALING_FACTOR = 2  # super-resolution upscaling factor
BATCH_SIZE = 128  # NOTE(review): defined but not referenced below — confirm the intended batching
# Build the list of training-image paths and preview them all in a 20x5 grid.
image_paths = [os.path.join(IMAGE_DIR, entry) for entry in os.listdir(IMAGE_DIR)]
plt.figure(figsize=(20, 40))
for idx, img_path in enumerate(image_paths):
    plt.subplot(20, 5, idx + 1)
    plt.title('Image %d' % idx)
    plt.axis('off')
    plt.tight_layout()
    plt.imshow(plt.imread(img_path))
plt.show()
# Create high-res subimages
def gen_subimg(path):
    """Load the image at `path` and return SUBIMAGE_QUANTITY random grayscale crops.

    Each crop has shape SUBIMAGE_SIZE with pixel values normalized to [0, 1].
    """
    # channels=3 forces an RGB tensor so rgb_to_grayscale cannot fail on
    # grayscale or RGBA PNGs (rgb_to_grayscale requires a 3-channel input).
    in_image = tf.image.decode_png(tf.io.read_file(path), channels=3)
    # Convert to grayscale and normalize to [0, 1.0]
    image = tf.image.rgb_to_grayscale(in_image)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # seed=i makes crop positions reproducible; note the same seed sequence is
    # reused for every source image, so crop offsets repeat across images.
    return [tf.image.random_crop(image, SUBIMAGE_SIZE, seed=i) for i in range(SUBIMAGE_QUANTITY)]
subimgs = list(map(gen_subimg, image_paths))
def plot_subimgs(images):
    """Display a list of single-channel sub-images in a wide (rows x 2*rows-ish) grid."""
    grid = math.ceil(math.sqrt(len(images)))
    rows, cols = math.ceil(grid / 2), grid * 2
    plt.figure(figsize=(20, 10))
    for pos, tile in enumerate(images, start=1):
        plt.subplot(rows, cols, pos)
        plt.axis('off')
        plt.imshow(np.asarray(tile[:, :, 0], dtype='float'), cmap='Greys_r')
    plt.show()
# Uncomment to plot the sub-images of every training image:
# list(map(plot_subimgs, subimgs));
# Plot the sub-images cropped from the first training image
plot_subimgs(subimgs[0])
def downsample_subimgs(images):
    """Produce the low-resolution counterpart of each high-res sub-image.

    Following Dong et al.: blur with a Gaussian kernel, downsample by the
    scaling factor, then upscale back to the original size with bicubic
    interpolation.
    """
    # tf.image.resize requires integer sizes; np.divide would produce floats
    # (33 / 2 == 16.5), so use floor division instead.
    small_size = [dim // SCALING_FACTOR for dim in SUBIMAGE_SIZE[0:2]]
    result = []
    for subimg in images:
        blurry = tfa.image.gaussian_filter2d(subimg, GAUSSIAN_SHAPE, GAUSSIAN_SIGMA)
        downsampled = tf.image.resize(blurry, small_size)
        # Upscale back to the original size via bicubic interpolation.
        final = tf.image.resize(downsampled, SUBIMAGE_SIZE[0:2], method='bicubic')
        result.append(final)
    return result
lr_subimgs = list(map(downsample_subimgs, subimgs))
# Uncomment to plot every image's low-res sub-images:
# list(map(plot_subimgs, lr_subimgs));
# Plot the low-res sub-images derived from the first training image
plot_subimgs(lr_subimgs[0])
# Crop subimages to be the same size as the output layer.
# Three valid (unpadded) convolutions of sizes 9, 1, 5 shrink a 33x33 input to
# 21x21 (33 - 8 - 0 - 4), i.e. a 6-pixel border is lost on each side.
subimgs = [tf.image.crop_to_bounding_box(subimg, 6, 6, 21, 21) for subimg in subimgs]
# Combine the original high-res subimages with the low-res derivatives to complete the dataset
# NOTE(review): lr_subimgs/subimgs are nested per-image lists, so each dataset
# element is a stack of SUBIMAGE_QUANTITY sub-images — confirm this matches the
# intended batching (BATCH_SIZE above is otherwise unused).
dataset = tf.data.Dataset.from_tensor_slices(((lr_subimgs, subimgs)))
# Model parameters (SRCNN, Dong et al.)
n_1 = 64  # feature maps in the first (patch extraction) layer
n_2 = 32  # feature maps in the second (non-linear mapping) layer
f_1 = 9 # Filter 1 size
f_2 = 1 # Filter 2 size
f_3 = 5 # Filter 3 size
c = SUBIMAGE_SIZE[2] # Number of color channels
3.1.1 Patch extraction and representation
A popular strategy in image restoration (e.g.,[1]) is to densely extract patches and then represent them by a set of pre-trained bases such as PCA, DCT, Haar, etc. This is equivalent to convolving the image by a set of filters, each of which is a basis. In our formulation, we involve the optimization of these bases into the optimization of the network. Formally, our first layer is expressed as an operation $F_1$:
$$ F_1(\vec Y) = \max(0,W_1∗\vec Y+B_1) $$where $W_1$ and $B_1$ represent the filters and biases respectively, and ’$∗$’ denotes the convolution operation. Here, $W_1$ corresponds to $n_1$ filters of support $c\times f_1\times f_1$, where $c$ is the number of channels in the input image, $f_1$ is the spatial size of a filter. Intuitively, $W_1$ applies $n_1$ convolutions on the image, and each convolution has a kernel size $c\times f_1\times f_1$. The output is composed of $n_1$ feature maps. $B_1$ is an $n_1$-dimensional vector, whose each element is associated with a filter. We apply the Rectified Linear Unit (ReLU, $\max(0,x)$) [33] on the filter responses.
3.1.2 Non-linear mapping
The first layer extracts an $n_1$-dimensional feature for each patch. In the second operation, we map each of these $n_1$-dimensional vectors into an $n_2$-dimensional one. This is equivalent to applying $n_2$ filters which have a trivial spatial support $1\times1$. This interpretation is only valid for $1\times1$ filters. But it is easy to generalize to larger filters like $3\times3$ or $5\times5$. In that case, the non-linear mapping is not on a patch of the input image; instead, it is on a $3\times3$ or $5\times5$ “patch” of the feature map. The operation of the second layer is:
$$ F_2(\vec Y) = \max(0,W_2∗F_1(\vec Y) +B_2) $$Here $W_2$ contains $n_2$ filters of size $n_1\times f_2\times f_2$, and $B_2$ is $n_2$-dimensional. Each of the output $n_2$-dimensional vectors is conceptually a representation of a high-resolution patch that will be used for reconstruction. It is possible to add more convolutional layers to increase the non-linearity. But this can increase the complexity of the model ($n_2\times f_2\times f_2\times n_2$ parameters for one layer), and thus demands more training time. We will explore deeper structures by introducing additional non-linear mapping layers in Section 4.3.3.
3.1.3 Reconstruction
In the traditional methods, the predicted overlapping high-resolution patches are often averaged to produce the final full image. The averaging can be considered as a pre-defined filter on a set of feature maps (where each position is the “flattened” vector form of a high-resolution patch). Motivated by this, we define a convolutional layer to produce the final high-resolution image:
$$ F(\vec Y) = W_3∗F_2(\vec Y) +B_3 $$Here $W_3$ corresponds to $c$ filters of a size $n_2\times f_3\times f_3$, and $B_3$ is a $c$-dimensional vector. If the representations of the high-resolution patches are in the image domain (i.e., we can simply reshape each representation to form the patch), we expect that the filters act like an averaging filter; if the representations of the high-resolution patches are in some other domains (e.g., coefficients in terms of some bases), we expect that $W_3$ behaves like first projecting the coefficients onto the image domain and then averaging. In either way, $W_3$ is a set of linear filters.
# Not currently used
def central_loss(y_true, y_pred):
    """MSE evaluated only at the central pixel of each image.

    Uses dynamic shapes (tf.shape) so it also works in graph mode, where the
    static batch dimension of `y_true`/`y_pred` is None and the original
    tf.slice call on `.shape[0]` would fail. Debug prints removed.
    """
    h = tf.shape(y_true)[1]
    w = tf.shape(y_true)[2]
    y_true_center = y_true[:, h // 2:h // 2 + 1, w // 2:w // 2 + 1, :]
    hp = tf.shape(y_pred)[1]
    wp = tf.shape(y_pred)[2]
    y_pred_center = y_pred[:, hp // 2:hp // 2 + 1, wp // 2:wp // 2 + 1, :]
    return tf.keras.losses.MSE(y_true_center, y_pred_center)
strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
history = None
with strategy.scope():
    model = models.Sequential(name='Greyscale_RELU')
    # F_1: patch extraction — n_1 filters of spatial size f_1, ReLU
    model.add(layers.Conv2D(n_1, f_1, activation='relu', use_bias=True, input_shape=SUBIMAGE_SIZE))
    # F_2: non-linear mapping — n_2 filters of spatial size f_2, ReLU
    model.add(layers.Conv2D(n_2, f_2, activation='relu', use_bias=True))
    # F_3: reconstruction — c filters (one per color channel), linear activation.
    # The original code used n_2 filters here, producing a 21x21x32 output that
    # silently broadcasts against the 21x21x1 ground truth inside the MSE loss;
    # per the paper, W_3 consists of c filters of size n_2 x f_3 x f_3.
    model.add(layers.Conv2D(c, f_3, use_bias=True))
    model.summary()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    checkpointCallback = tf.keras.callbacks.ModelCheckpoint('Greyscale_RELU.hdf5', monitor="mean_squared_error", verbose=0, save_best_only=True, save_weights_only=False, mode="auto", save_freq="epoch")
    model.compile(optimizer=optimizer, loss=tf.keras.losses.MeanSquaredError(), metrics=['mean_squared_error'])
    history = model.fit(dataset, epochs=50000, verbose=0, callbacks=[TqdmCallback(verbose=1), checkpointCallback])
# PSNR (dB) per epoch from the recorded training MSE. Pixels lie in [0, 1],
# so the 20*log10(MAX) term is exactly zero.
mse_history = history.history['mean_squared_error']
max_term = 20 * math.log(1, 10)
psnr = max_term - 10 * np.log10(mse_history)
plt.figure(figsize=(6, 3), dpi=200)
plt.plot(psnr, label='PSNR')
plt.xlabel('Training Epoch')
plt.ylabel('Average test PSNR (dB)')
plt.ylim(25, 32.5)
plt.legend(loc='upper right')
plt.show()
# Overlay PSNR (left axis) and raw MSE (right axis, log scale) per epoch.
fig, ax1 = plt.subplots(figsize=(6, 3), dpi=200)
ax1.set_xlabel('Training epoch')
ax1.set_ylabel('Average test PSNR (dB)', color='C0')
ax1.plot(psnr, color='C0')
ax1.set_ylim(30, 32)
ax1.tick_params(axis='y', labelcolor='C0')
# Second y-axis sharing the same x-axis.
ax2 = ax1.twinx()
ax2.set_ylabel('Mean Squared Error', color='C1')
ax2.set_yscale('log')
ax2.plot(history.history['mean_squared_error'], color='C1')
ax2.tick_params(axis='y', labelcolor='C1')
fig.tight_layout()
plt.show()